function [info] = adahspsvrg_fp_mixed_l1l2_solver(pname,fun,lambda,parms)
%=========================================================================
% Purpose : AdaHSPG+ solver for 
%           minimizing logistic function plus mixed l1/l2 regularization.
%
%           Each outer iteration draws a large batch to estimate the
%           gradient, then performs either one pass of variance-reduced
%           proximal steps over shuffled mini-batches (ProxSG stage) or a
%           single half-space projection step (HalfSpace stage). The stage
%           is re-selected by switch_adaptive whenever num_steps crosses
%           the next multiple of T.
%=========================================================================
% Input:
%   pname   not read by this function; kept so the call signature is
%           unchanged for existing callers
%   fun     object of type LogRegCost (see function LogRegCost.m); the
%           members used here are num_features, num_samples, setExpterm,
%           setSigmoid, func and grad
%   lambda  weighting parameter in objective function:
%           f(x) + lambda * Omega(x), where Omega is the mixed l1/l2 norm
%           returned by utils.mixed_l1_l2_norm over the variable groups
%   parms   structure of control parameters (see function fPlusL1_spec.m).
%           Fields read: max_epoch, batch_size_small, batch_size_large,
%           num_groups, epsilon, eps_adapt, eps_adapt_coeff, alpha, gamma,
%           kappa, mu, T
%
% Output:
%   info    structure holding information parameters:
%             status              0 when the step budget is exhausted
%             Fval, fval, Omegaval, group_sparsity
%                                 mean of the (at most two) retained
%                                 history entries
%             nnz                 retained history of non-zero counts
%             density, sparsity   mean(nnz)/n and its complement
%             runtime             wall-clock seconds (tic/toc)
%             eps_increase_count, eps_decrease_count, epsilon
%                                 bookkeeping of the epsilon adaptation
%             x                   final iterate
%=========================================================================

% Get control parameters.
max_epoch        = parms.max_epoch;
batch_size_small = parms.batch_size_small;
batch_size_large = parms.batch_size_large;
num_groups       = parms.num_groups;
epsilon          = parms.epsilon;
num_switch       = 0;
eps_adapt        = parms.eps_adapt;
% eps_adapt_stride = parms.eps_adapt_stride;
eps_adapt_coeff  = parms.eps_adapt_coeff;

% Size of the problem (the number of optimization variables / samples)
n             = fun.num_features;
m             = fun.num_samples;

% Batch sizes must be integers in [1, m]. The upper clamp matters because
% randsample cannot draw more than m indices without replacement, and a
% small batch larger than m would otherwise yield zero mini-batches per
% pass, so num_steps would never advance and the loop would never end.
batch_size_small = min(floor(max(batch_size_small, 1)), m);
batch_size_large = min(floor(max(batch_size_large, 1)), m);
group_indexes = utils.split_groups(n, num_groups);

% The step budget is derived from the epoch budget (any parms.max_steps
% value is intentionally not used).
max_steps     = ceil(max_epoch * m / batch_size_small);
record_freq   = ceil(max_steps / max_epoch) * 10;

num_records   = 0;
num_steps     = 1;
% Initial estimate of a solution.
x             = zeros(n,1);

indexes       = 1:m;
alpha         = parms.alpha;
gamma         = parms.gamma;
kappa         = parms.kappa;
mu            = parms.mu;
% Number of steps between two stage-selection checks.
T             = ceil(parms.T * max_steps / max_epoch);
alg_start_time = tic;
info.Fval     = [];
info.fval     = [];
info.Omegaval = [];
info.nnz      = [];
info.group_sparsity = [];
%========================
% Begin: main while loop.
%========================
do_prox_sg = true;

header = '   Step  | psi_value  f_value  Omega  |   gs   |  norm_opt |  eps |  step type |\n';
hline  = '--------------------------------------------------------------------------------\n';
prev_g_sparsity = 0.0;
g_sparsity = 0.0;
prev_epsilon = 0.0;
eps_increase_count = 0;
eps_decrease_count = 0;

while(1)      
    if mod(num_steps, record_freq) == 1
        % Check progress: evaluate objective and optimality measure on
        % the full data set.
        fun.setExpterm(x, indexes);
        fun.setSigmoid();
        f = fun.func(indexes); 
        norm_l1l2_x = utils.mixed_l1_l2_norm(x, group_indexes);
        F = f + lambda * norm_l1l2_x;
        grad_f_full = fun.grad(indexes);
        prox_delta = prox_mapping_group_delta(x, grad_f_full, lambda, alpha, group_indexes);
        norm_opt = norm(prox_delta / alpha);
        % Named num_nonzero (not nnz) so the builtin nnz is not shadowed.
        num_nonzero = sum(x~=0);
        g_sparsity = utils.compute_group_sparsity(x, group_indexes);
        num_records = num_records + 1;
        if mod(num_records, 5) == 1
            fprintf(hline);
            fprintf(header);
        end
        % Both labels are padded to the same width so the table columns
        % line up.
        if do_prox_sg
            step_label = '   ProxSG   ';
        else
            step_label = '  HalfSpace ';
        end
        fprintf('   %4d  |  %.4f   %.4f   %.4f  | %.4f |   %.4f  | %.2f |%s|\n', num_steps, F, f, norm_l1l2_x, g_sparsity, norm_opt, epsilon, step_label);
    end
 

    % Keep a short history of the most recent statistics; it is averaged
    % on termination. NOTE: this runs on every outer iteration, so the
    % stored values are those of the latest progress check (g_sparsity
    % may additionally have been refreshed by the epsilon adaptation).
    if num_records <= 2
        info.Fval = [info.Fval F];
        info.fval = [info.fval f];
        info.Omegaval = [info.Omegaval norm_l1l2_x];
        info.nnz  = [info.nnz num_nonzero];
        info.group_sparsity = [info.group_sparsity g_sparsity];
    else
        info.Fval = [info.Fval(2:2) F];
        info.fval = [info.fval(2:2) f];
        info.Omegaval = [info.Omegaval(2:2) norm_l1l2_x];
        info.nnz  = [info.nnz(2:2) num_nonzero];
        info.group_sparsity = [info.group_sparsity(2:2) g_sparsity];
    end
    
    % Termination condition:
    if num_steps > max_steps
        info.status = 0;
        info.Fval = mean(info.Fval);
        info.fval = mean(info.fval);
        info.Omegaval = mean(info.Omegaval);
        info.density = mean(info.nnz) / n;
        info.sparsity = 1.0 - info.density;
        info.group_sparsity = mean(info.group_sparsity);
        alg_end_time = toc(alg_start_time);
        info.runtime = alg_end_time;
        info.eps_decrease_count = eps_decrease_count;
        info.eps_increase_count = eps_increase_count;
        info.epsilon = epsilon;
        info.x = x;
        fprintf('Maximum epoch has been reached. Run time %f\n',alg_end_time);
        break
    end
    
    % Sample a large batch and evaluate its gradient at the current x.
    % This gradient is used by the stage switch, by the half-space step,
    % and as the reference gradient of the variance-reduced proximal steps.
    indexes_batch_large = sort(randsample(m, batch_size_large));
    fun.setExpterm(x, indexes_batch_large);
    fun.setSigmoid();
    grad_f_batch_large = fun.grad(indexes_batch_large);
    
    % Stage switch: the first outer iteration is always ProxSG; afterwards
    % the stage is re-evaluated each time num_steps passes num_switch * T.
    if num_switch == 0
        do_prox_sg = true;
        num_switch = num_switch + 1;
    elseif num_steps / T > num_switch
        do_prox_sg = switch_adaptive(x, grad_f_batch_large, mu, lambda, alpha, kappa, group_indexes);
        num_switch = num_switch + 1;
        % Adapt epsilon (only when enabled and a half-space step follows).
        if eps_adapt == 1 && ~do_prox_sg
            prox_delta = prox_mapping_group_delta(x, grad_f_batch_large, lambda, alpha, group_indexes);
            norm_opt_large = norm(prox_delta / alpha);  
            % fprintf("num_steps: %d, norm_opt_large: %.4f\n", num_steps, norm_opt_large);
            g_sparsity = utils.compute_group_sparsity(x, group_indexes);
            if norm_opt_large < 1e-1 && prev_g_sparsity == g_sparsity
                % Optimality measure is small and group sparsity did not
                % change since the last check: raise epsilon (capped at
                % 0.999) and remember the previous value.
                prev_epsilon = epsilon;
                % epsilon = epsilon + eps_adapt_stride;
                epsilon = max(epsilon, 0.1) * eps_adapt_coeff;
                epsilon = min(epsilon, 0.999);
                eps_increase_count = eps_increase_count + 1;
            elseif norm_opt_large > 1e-1
                % Optimality measure is large: fall back to the remembered
                % epsilon and shrink it further.
                epsilon = prev_epsilon;
                epsilon = epsilon / eps_adapt_coeff;
                epsilon = max(epsilon, 0.0);
                eps_decrease_count = eps_decrease_count + 1;
            end
            prev_g_sparsity = g_sparsity;
        end
    end

    shuffled_indexes = utils.indexes_shuffle(indexes);
    % Round up so that a trailing partial mini-batch is also processed
    % (end_idx below is clamped for exactly that case). Without rounding,
    % a non-integer upper bound of the for-range drops the last batch.
    num_batches = ceil(length(shuffled_indexes) / batch_size_small);
    
    if ~do_prox_sg
    % if half-space performs
        trial_x = x;
        grad_psi = subgrad_psi(x, grad_f_batch_large, lambda, group_indexes);
        prox_delta = prox_mapping_group_delta(x, grad_f_batch_large, lambda, alpha, group_indexes);
        for j = 1 : length(group_indexes)
            group = group_indexes{j};
            lhs = norm(x(group));
            rhs = kappa * norm(grad_psi(group));
            prox_iter_norm = norm(x(group) + prox_delta(group));
            % Only groups that are non-zero, sufficiently large relative
            % to their subgradient, and not zeroed by a proximal step are
            % updated; all other groups stay untouched in this stage.
            if lhs > 0 && lhs >= rhs && prox_iter_norm ~= 0
                trial_x(group) = x(group) - alpha * grad_psi(group);
                % Project the group to zero when the trial point leaves
                % the half-space {z : z' * x_g >= epsilon * ||x_g||^2}.
                if dot(trial_x(group), x(group)) < (epsilon * norm(x(group)) ^ 2)
                    x(group) = 0;
                else
                    x(group) = trial_x(group);
                end
            end
        end
        num_steps = num_steps + 1;
    else
    % if proxsvrg performs
        % x stays fixed as the reference point of this pass; x_tilde is
        % the running iterate.
        x_tilde = x;
        for i = 1 : num_batches
            start_idx = 1 + (i-1) * batch_size_small;
            end_idx = min( i * batch_size_small, length(shuffled_indexes) );
            minibatch_idxes = shuffled_indexes(start_idx:end_idx); 
            
            % Mini-batch gradient at the running iterate x_tilde
            fun.setExpterm(x_tilde, minibatch_idxes);
            fun.setSigmoid();
            grad_f_x_tilde = fun.grad(minibatch_idxes);    
 
            % Mini-batch gradient at the reference point x
            fun.setExpterm(x, minibatch_idxes);
            fun.setSigmoid();
            grad_f_x = fun.grad(minibatch_idxes);
            
            % Variance-reduced gradient estimate
            v = grad_f_x_tilde - grad_f_x + grad_f_batch_large;
            prox_delta = prox_mapping_group_delta(x_tilde, v, lambda, alpha, group_indexes);
            x_tilde = x_tilde + prox_delta;
            num_steps = num_steps + 1;
        end
        x = x_tilde;
    end

    % Geometric step-size decay, applied once per outer iteration.
    alpha = alpha * gamma;
end
%========================
% END: main while loop.
%========================


end

%% Proximal mapping for mixed l1/l2, returned as a displacement from x
function prox_delta = prox_mapping_group_delta(x, grad_f, lambda, alpha, group_indexes)
% Takes the gradient step x - alpha * grad_f, rescales every group of the
% result by max(0, 1 - alpha*lambda / (group norm + 1e-6)), and returns the
% difference between that shrunk point and x.
%
%   x, grad_f      vectors of equal size
%   lambda, alpha  scalars; their product is the shrinkage threshold
%   group_indexes  cell array, each cell holding the indices of one group
%
% Entries of x that belong to no group map to zero in the shrunk point, so
% their displacement is -x.

threshold = alpha * lambda;
stepped   = x - alpha * grad_f;
shrunk    = zeros(size(x));

for g = 1 : length(group_indexes)
    idx = group_indexes{g};
    % The 1e-6 offset keeps the division finite for an all-zero group.
    scale = max(0.0, 1.0 - threshold / (norm(stepped(idx)) + 1e-6));
    shrunk(idx) = scale * stepped(idx);
end

prox_delta = shrunk - x;
end

%% Proximal mapping for mixed l1/l2, returned as the new point
function new_x = prox_mapping_group(x, grad_f, lambda, alpha, group_indexes)
% Group-wise shrinkage of the gradient step x - alpha * grad_f: each group
% is multiplied by max(0, 1 - alpha*lambda / (group norm + 1e-6)).
%
%   x, grad_f      vectors of equal size
%   lambda, alpha  scalars; their product is the shrinkage threshold
%   group_indexes  cell array, each cell holding the indices of one group
%
% Entries that belong to no group are zero in the result.

new_x = zeros(size(x));
grad_step = x - alpha * grad_f;
shrink_by = alpha * lambda;

k = 1;
while k <= length(group_indexes)
    members = group_indexes{k};
    block = grad_step(members);
    % The 1e-6 offset keeps the division finite for an all-zero group.
    factor = max(0.0, 1.0 - shrink_by / (norm(block) + 1e-6));
    new_x(members) = factor * block;
    k = k + 1;
end

end


%% Subgradient of psi = f + lambda * (mixed l1/l2 norm)
function subgrad = subgrad_psi(x, grad_f, lambda, group_indexes)
% Starts from grad_f and adds, for every group, the term
% lambda * x_g / (||x_g|| + 1e-6). For an all-zero group the added term is
% zero, so the result equals grad_f on that group.
%
%   x, grad_f      vectors of equal size
%   lambda         scalar regularization weight
%   group_indexes  cell array, each cell holding the indices of one group

subgrad = grad_f;

for g = 1 : length(group_indexes)
    idx = group_indexes{g};
    x_g = x(idx);
    % The 1e-6 offset avoids a division by zero for an all-zero group.
    reg_term = lambda * x_g / (norm(x_g) + 1e-6);
    subgrad(idx) = subgrad(idx) + reg_term;
end

end

%% Stage switch: choose between the proximal stage and the half-space stage
function do_prox_sg = switch_adaptive(x, grad_f_full, mu, lambda, alpha, kappa, group_indexes)
% Splits the groups into two sets and compares how much of the proximal
% displacement falls into each:
%   - "half-space" groups: non-zero, with ||x_g|| >= kappa * ||subgrad_g||,
%     and not mapped to zero by the proximal step;
%   - all remaining groups count towards the proximal stage.
% Returns false (run the half-space stage) when the summed displacement
% norm of the remaining groups is at most mu times that of the half-space
% groups, and true otherwise.

step_delta = prox_mapping_group_delta(x, grad_f_full, lambda, alpha, group_indexes);
psi_subgrad = subgrad_psi(x, grad_f_full, lambda, group_indexes);

total_hs  = 0.0;
total_psg = 0.0;

for g = 1 : length(group_indexes)
    idx = group_indexes{g};
    x_norm        = norm(x(idx));
    subgrad_bound = kappa * norm(psi_subgrad(idx));
    next_norm     = norm(x(idx) + step_delta(idx));
    delta_norm    = norm(step_delta(idx));

    eligible_for_hs = x_norm > 0 && x_norm >= subgrad_bound && next_norm ~= 0;
    if eligible_for_hs
        total_hs = total_hs + delta_norm;
    else
        total_psg = total_psg + delta_norm;
    end
end 

% Written as a negated "<=" so that a NaN on either side still selects
% the proximal stage.
do_prox_sg = ~(total_psg <= mu * total_hs);
end

